from typing import cast

from inspect_ai.solver import TaskState


def extract_code(block: str) -> str:
    """
    Strip markdown code fences from a block of text and return the bare code.

    Every occurrence of the "```python" opening fence and of the plain "```" fence is
    deleted from the text, after which leading and trailing whitespace is trimmed.

    Parameters:
    -----------
    block: str
        Text that may contain markdown-style fences around Python code.

    Returns:
    --------
    str
        The input with all fence markers removed and surrounding whitespace stripped.

    """
    # Order matters: the language-tagged fence has to go first, otherwise removing the
    # bare "```" would leave a stray "python" behind.
    cleaned = block
    for fence in ("```python", "```"):
        cleaned = cleaned.replace(fence, "")
    return cleaned.strip()


def get_generated_code(state: TaskState) -> list[str]:
    """
    Extract the code generated by the assistant in response to each subproblem posed.

    Every message in `state.messages` whose role is "assistant" is passed through
    `extract_code`, and the results are collected in message order.

    Parameters:
    -----------
        state: TaskState
            Task state whose `messages` are scanned for assistant replies and whose
            `metadata["sub_steps"]` gives the expected number of subproblems.

    Returns:
    -------
        list[str]
            Each element is the model's solution to a subproblem, preserving subproblem order.

    Raises:
    -------
    AssertionError
        If the number of assistant messages does not match the number of subproblems in the task metadata.

    """
    # NOTE(review): `cast` is a type-checker hint only and performs no runtime conversion;
    # this assumes each assistant message's content is already a plain string -- confirm.
    assistant_messages = [
        extract_code(cast(str, message.content))
        for message in state.messages
        if message.role == "assistant"
    ]
    expected_count = len(state.metadata["sub_steps"])
    # Raise explicitly instead of using a bare `assert`, so the check is not stripped
    # when Python runs with -O; the exception type stays AssertionError for callers.
    if len(assistant_messages) != expected_count:
        raise AssertionError(
            f"Expected {expected_count} assistant messages (one per subproblem), "
            f"found {len(assistant_messages)}."
        )
    return assistant_messages


def subproblem_str_to_int(num: str) -> int:
    """
    Return the subproblem number encoded in a SciCode problem identifier.

    Identifiers have the form <main_problem_no>.<subproblem_no>. The string is split on
    "." and the second piece is converted to an integer.

    Parameters:
    -----------
    num: str
        Problem identifier formatted as <main_problem_no>.<subproblem_no>.

    Returns:
    -------
    int
        The subproblem number as an integer.

    Example:
    -------
    >>> subproblem_str_to_int("77.1")
    1

    """
    pieces = num.split(".")
    # Index 1 is the part after the first dot, i.e. the subproblem number.
    subproblem_part = pieces[1]
    return int(subproblem_part)
